#include <stdlib.h>
#include <assert.h>
#include <limits.h>
#include <math.h>

#include "io_psf.h"
#include "io_charmm_inp.h"
#include "esmd_types.h"
#include "cell.h"
#include "bonded.h"
#include "memory_cpp.hpp"
/* Three-way integer comparison: -1, 0 or 1. Evaluates each argument twice. */
#define CMP_INT(a, b) (((a) > (b)) - ((a) < (b)))

/*
 * qsort comparator for directed bonds stored as int[2] = {from, to}.
 * Orders by `from`, then by `to`. The first key is weighted by 3 so that it
 * always dominates the second, whose comparison lies in [-1, 1].
 */
static int cmp_bond(const void *va, const void *vb) {
  const int *lhs = static_cast<const int *>(va);
  const int *rhs = static_cast<const int *>(vb);
  const int major = CMP_INT(lhs[0], rhs[0]);
  const int minor = CMP_INT(lhs[1], rhs[1]);
  return 3 * major + minor;
}

/*
 * qsort comparator for improper dihedrals stored as int[4].
 * Orders lexicographically on all four atom indices. The previous version
 * compared element 1 twice and never looked at element 3, so impropers that
 * differed only in their last atom compared equal.
 * Returns -1, 0 or 1; qsort only relies on the sign.
 */
static int cmp_impr(const void *va, const void *vb) {
  const int *a = static_cast<const int *>(va);
  const int *b = static_cast<const int *>(vb);
  for (int k = 0; k < 4; k++) {
    if (a[k] != b[k]) {
      return a[k] < b[k] ? -1 : 1;
    }
  }
  return 0;
}

/* qsort comparator for int: returns -1, 0 or 1 as *va is <, ==, > *vb. */
static int cmp_int(const void *va, const void *vb) {
  const int lhs = *static_cast<const int *>(va);
  const int rhs = *static_cast<const int *>(vb);
  if (lhs < rhs) {
    return -1;
  }
  return lhs > rhs ? 1 : 0;
}
/* qsort comparator for long tags: returns -1, 0 or 1 as *va is <, ==, > *vb. */
static int cmp_tag(const void *va, const void *vb) {
  const long lhs = *static_cast<const long *>(va);
  const long rhs = *static_cast<const long *>(vb);
  return lhs < rhs ? -1 : (lhs > rhs ? 1 : 0);
}
/*
 * Remove consecutive duplicates from base[0 .. n) in place. On sorted input
 * this leaves each distinct value exactly once in base[0 .. result).
 * Returns the number of values kept. An empty input (n <= 0) yields 0;
 * previously it reported one element.
 */
static int unique_tag(long *base, int n) {
  if (n <= 0) {
    return 0;
  }
  int n_uniq = 1;
  for (int i = 1; i < n; i++) {
    /* base[i - 1] still holds its original value: writes only go to
       indices <= i, and index i - 1 is only ever rewritten with itself. */
    if (base[i] != base[i - 1]) {
      base[n_uniq++] = base[i];
    }
  }
  return n_uniq;
}
/*
 * Sort tag[0 .. n) ascending, in place, with a scalar bitonic sorting network
 * (no vectorisation).
 *
 * The network needs a power-of-two length. The array is therefore padded up to
 * npad, the smallest power of two >= n (npad is 1 when n <= 1), with LONG_MAX
 * sentinels, which sort to the end. The caller's buffer must hold at least
 * npad elements, and tag[n .. npad) is overwritten. With n == 0, one sentinel
 * is still written to tag[0].
 */
static void sort_tag(long *tag, int n) {
  int npad = 1;
  while (npad < n) npad <<= 1;
  /* LONG_MAX rather than a 64-bit literal: the sentinel must be the largest
     long whatever the width of long, otherwise the padding would not sort last. */
  for (int i = n; i < npad; i++) {
    tag[i] = LONG_MAX;
  }

  /* Stage `block` produces sorted runs of length `block`. The runs alternate
     ascending/descending, so each pair of runs is the bitonic input of the next
     stage. The last stage (block == npad) is a single ascending run. */
  for (int block = 2; block <= npad; block <<= 1) {
    for (int stride = block >> 1; stride > 0; stride >>= 1) {
      for (int base = 0; base < npad; base += block) {
        const bool ascending = (base & block) == 0;
        for (int lo = base; lo < base + block; lo += 2 * stride) {
          for (int i = lo; i < lo + stride; i++) {
            const long a = tag[i];
            const long b = tag[i + stride];
            /* compare-exchange: swap when the pair is out of order for this run */
            if (ascending ? a > b : a < b) {
              tag[i] = b;
              tag[i + stride] = a;
            }
          }
        }
      }
    }
  }
}
/*
 * Locate `tag` in a cell and return its slot index.
 * Guest slots (stored after the first CELL_CAP slots, i.e. CELL_CAP + g for
 * g < cell->nguest) are scanned first, then the resident slots
 * 0 .. cell->natom. Returns 0x7fffffff when the tag is in neither.
 */
INLINE int search_tag(celldata_t *cell, long tag) {
  const long *guest = cell->tag + CELL_CAP;
  for (int g = 0; g < cell->nguest; g++) {
    if (guest[g] == tag) return CELL_CAP + g;
  }
  for (int a = 0; a < cell->natom; a++) {
    if (cell->tag[a] == tag) return a;
  }
  return 0x7fffffff;
}
/*
 * Remap `tag` to the periodic image nearest to a reference atom along one axis.
 *
 *   d     - displacement of the tagged atom from the reference atom
 *   g     - box length along that axis (must be positive)
 *   tag   - tag to remap
 *   local - tag shift applied per box length removed from d
 *           (process_pbc passes its per-process atom count here)
 *   total - number of valid tags (must be positive); the result is wrapped
 *           into [0, total)
 *
 * d is folded into [-g/2, g/2]. Each box length added to or removed from d
 * moves the tag by `local` in the same direction.
 */
INLINE long minimum_image(real d, real g, long tag, long local, long total) {
  /* Non-positive g or total would make the loops below spin forever. */
  assert(g > 0);
  assert(total > 0);
  long ret = tag;
  while (d < -0.5 * g) {
    ret += local;
    d += g;
  }
  while (d > 0.5 * g) {
    ret -= local;
    d -= g;
  }
  /* wrap the shifted tag back into [0, total) */
  while (ret < 0) ret += total;
  while (ret >= total) ret -= total;
  return ret;
}
void process_pbc(long natoms, long nprocs, bond_graph_t *graph, long (*excls)[MAX_EXCLS_ATOM], long (*chain2)[MAX_EXCLS_ATOM][2], long (*scales)[MAX_SCALS_ATOM], double (*x)[3], vec<real> *glen){
  long natoms_total = natoms * nprocs;

  for (int i = 0; i < natoms; i ++){
    for (int jj = 0; excls[i][jj] != -1; jj ++){
      long j = excls[i][jj];
      real dx = x[j][0] - x[i][0];
      /* if (i == 18699) { */
      /* 	printf("dxe18699: %f\n", dx); */
      /* } */
      excls[i][jj] = minimum_image(dx, glen->x, excls[i][jj], natoms, natoms_total);
    }
    for (int jj = 0; scales[i][jj] != -1; jj ++){
      long j = scales[i][jj];
      real dx = x[j][0] - x[i][0];
      /* if (i == 18699) { */
      /* 	printf("dxs18699: %f\n", dx); */
      /* } */
      scales[i][jj] = minimum_image(dx, glen->x, scales[i][jj], natoms, natoms_total);
    }
    for (int jj = 0; chain2[i][jj][0] != -1; jj ++){
      long j0 = chain2[i][jj][0];
      long j1 = chain2[i][jj][1];
      chain2[i][jj][0] = minimum_image(x[j0][0] - x[i][0], glen->x, j0, natoms, natoms_total);
      chain2[i][jj][1] = minimum_image(x[j1][0] - x[i][0], glen->x, j1, natoms, natoms_total);
      /* if (i == 18699) { */
      /* 	printf("dxc18699: %f %f\n", x[j0][0] - x[i][0], x[j1][0] - x[i][0]); */
      /* } */

    }
    for (int jj = graph->first_bond[i]; jj < graph->first_bond[i + 1]; jj ++){
      long j = graph->bonded_tag[jj];
      /* if (i == 18699) { */
      /* 	printf("dxb18699: %f\n", x[j][0] - x[i][0]); */
      /* } */
      graph->bonded_tag[jj] = minimum_image(x[j][0] - x[i][0], glen->x, j, natoms, natoms_total);
    }
  }
}
#include "timer.h"
/*
 * Timers for find_bonds. In find_bonds, FIND brackets the whole search, while
 * SORT, SEARCH and REMAP appear only in commented-out timer calls.
 * NOTE(review): DEF_TIMER comes from timer.h; the string is presumably the
 * label the timer is reported under — confirm.
 */
DEF_TIMER(FIND, "bonded/find bonds")
DEF_TIMER(SORT, "bonded/sort guests")
DEF_TIMER(SEARCH, "bonded/searching guest")
DEF_TIMER(REMAP, "bonded/remap structures")
/*
 * Add `in` to a mini Bloom filter made of four 64-bit words out[0..3].
 * Word w gets the bit selected by the 6-bit window (in >> 4*w) & 63, so the
 * windows start at bits 0, 4, 8 and 12 of `in` and overlap.
 *
 * The bit is built from an unsigned 1UL: shifting a signed 1L into bit 63
 * overflows long, which is undefined behaviour in C and before C++14.
 * NOTE(review): assumes long is 64 bits wide, as the & 63 mask implies.
 */
INLINE void htag4(long *out, long in) {
  for (int w = 0; w < 4; w++) {
    out[w] |= static_cast<long>(1UL << ((in >> (4 * w)) & 63L));
  }
}
/*
 * Query the 4x64-bit mini Bloom filter populated by htag4.
 * Returns 1 when, for every word w, the bit selected by (in >> 4*w) & 63 is
 * set in filter[w], and 0 otherwise. Words are checked in order and the scan
 * stops at the first missing bit. A 1 may be a false positive; a 0 is exact.
 *
 * Uses an unsigned shift for the same reason as htag4: 1L << 63 overflows
 * signed long.
 */
INLINE int bmatch(long *filter, long in) {
  for (int w = 0; w < 4; w++) {
    const unsigned long bit = 1UL << ((in >> (4 * w)) & 63L);
    if ((static_cast<unsigned long>(filter[w]) & bit) == 0) {
      return 0;
    }
  }
  return 1;
}
/*
* find 1-2/1-3 exclusions, 1-4 scale, bonded back
*
* The unreachable body below appears to gather, for each local cell, atoms of
* neighbouring cells whose tags appear in a per-cell tag list ("guests").
* They are stored in the cell's guest slots (tag/t/rmass at offset CELL_CAP,
* guest_id) together with per-neighbour offset tables (first_guest_cell,
* first_excl_cell, first_scal_cell).
*
* NOTE(review): find_bonds starts with an unconditional `return;`, so on every
* build (including __sw__, whose branch comes after it) the function does
* nothing and everything after that statement is dead code.
*/
#ifdef __sw__
void find_bonds_sw(cellgrid_t *grid);
#endif
void find_bonds(cellgrid_t *grid) {
  /* NOTE(review): early exit disables the whole function (see header). */
  return;
  #ifdef __sw__
  timer_start(FIND);
  find_bonds_sw(grid);
  timer_stop(FIND);
  return;
  #endif
  timer_start(FIND);
  /* NOTE(review): max_guest and max_guest_uniq are never used; nsearch and
     npass are never incremented and are only printed at the end. */
  int max_bonds = 0, max_chain2 = 0, max_excl = 0, max_scal = 0, max_impr = 0;
  int max_guest = 0, max_guest_uniq = 0;
  int nneighbor = hdcell(grid->nn, grid->nn, grid->nn, grid->nn);
  int self = hdcell(grid->nn, 0, 0, 0);
  int nsearch = 0, npass = 0, ntot = 0;
  // int excl_mask[MAX_EXCL_CELL];
  // FOREACH_CELL(grid, cx, cy, cz, cell) {
  //   cell->bfilter[0] = 0;
  //   cell->bfilter[1] = 0;
  //   cell->bfilter[2] = 0;
  //   cell->bfilter[3] = 0;
  //   /* We use 4x64bit mini bloom filters*/
  //   for (int i = 0; i < cell->natom; i ++) {
  //     htag4(cell->bfilter, cell->tag[i]);
  //   }
  // }
  /*loop cells*/
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, icell) {
    // timer_start(SORT);
    long guest_tags[MAX_EXCL_CELL];
    //int guest_id[MAX_EXCLS_ATOM * CELL_CAP];
    /* NOTE(review): guest_bonds is never used. */
    int guest_bonds[MAX_EXCL_CELL][2];
    /* NOTE(review): nothing increments nguest_full (the loop that filled
       guest_tags is commented out), so no real tags are sorted or searched. */
    int nguest_full = 0;
    // for (int i = 0; i < icell->natom; i++) {
    //   for (int jj = icell->first_excl_atom[i]; jj < icell->first_excl_atom[i + 1]; jj++) {
    //     guest_tags[nguest_full++] = icell->excl_tag[jj]; //icell->excl_tag[i][jj];
    //   }
    // }
    // timer_start(SORT);
    // qsort(guest_tags, nguest_full, sizeof(long), cmp_tag);
    sort_tag(guest_tags, nguest_full);
    // timer_stop(SORT);
    int nguest_uniq = unique_tag(guest_tags, nguest_full);
    int nguest = 0;
    /* NOTE(review): nexcl and nscal are never incremented in this body, so
       the excl/scal offset tables are written as all zeros. */
    int nexcl = 0;
    int nscal = 0;
    /* guest data lives after the first CELL_CAP slots of each per-cell array */
    long *guest_tag = icell->tag + CELL_CAP;
    real *guest_rmass = icell->rmass + CELL_CAP;
    int *guest_t = icell->t + CELL_CAP;
    ntot += nguest_uniq;
    // memset(excl_mask, 0, sizeof(excl_mask));
    // long bfilter_exclusions[4];
    // for (int i = 0; i < 4; i ++) bfilter_exclusions[i] = 0;

    FOREACH_NEIGHBOR(grid, cx, cy, cz, dx, dy, dz, jcell) {
      int did = hdcell(grid->nn, dx, dy, dz);
      /* record where this neighbour's entries begin */
      icell->first_guest_cell[did] = nguest;
      icell->first_excl_cell[did] = nexcl;
      icell->first_scal_cell[did] = nscal;

      /*never invite atom in center cell as a guest*/
      if (dx != 0 || dy != 0 || dz != 0) {
        /* linear scan of jcell for every wanted tag */
        for (int ig = 0; ig < nguest_uniq; ig++) {
          //if (!bmatch(jcell->bfilter, guest_tags[ig])) continue;
          for (int j = 0; j < jcell->natom; j++) {
            if (jcell->tag[j] == guest_tags[ig]) {
              icell->guest_id[nguest] = j;
              /*rmass, type can be collected here since they do not change*/
              /*x and f should be collected during bonded and shake*/

              guest_tag[nguest] = jcell->tag[j];
              guest_t[nguest] = jcell->t[j];
              guest_rmass[nguest] = jcell->rmass[j];
              nguest++;
            }
          }
        }
      }

      /* NOTE(review): this is the last statement of the loop body, so the
         `continue` has no effect. */
      if (hdcell(grid->nn, dx, dy, dz) > self) continue;
    }
    // timer_stop(SEARCH);
    // timer_start(REMAP);
    /* NOTE(review): orig_excl is only updated in commented-out code. */
    int orig_excl = 0;
    // for (int i = 0; i < icell->natom; i++) {
    //   /*if bond match*/
    //   for (int ie = icell->first_excl_atom[i]; ie < icell->first_excl_atom[i + 1]; ie++) {
    //     orig_excl++;
    //   }
    // }
    /* end offsets of the per-neighbour tables */
    icell->first_scal_cell[nneighbor + 1] = nscal;
    icell->first_excl_cell[nneighbor + 1] = nexcl;
    // if (orig_excl != nexcl) {
    //   puts("missing exclusions, please check");
    // }
    icell->first_guest_cell[nneighbor + 1] = nguest;
    icell->nguest = nguest;
    /* NOTE(review): checked only after the guest slots have been written. */
    assert(nguest < MAX_CELL_GUEST);
    int cell_bonds = 0, cell_chain2 = 0, cell_impr = 0;
    max_bonds = max(max_bonds, cell_bonds);
    max_chain2 = max(max_chain2, cell_chain2);
    max_excl = max(max_excl, nexcl);
    max_scal = max(max_scal, nscal);
    max_impr = max(max_impr, cell_impr);
  }
  /* NOTE(review): nmax and nmax_all are never used. */
  int nmax[5], nmax_all[5];
  /* debug statistics */
  printf("%d %d %d\n", nsearch, npass, ntot);
  printf("max bonds=%d, max chain2=%d, max excl=%d, max scal=%d, max impr=%d\n", max_bonds, max_chain2, max_excl, max_scal, max_impr);
  // exit(1);
  timer_stop(FIND);
}

/*
 * Build per-atom exclusion and 1-4 scaling lists from the bond graph.
 *
 * The graph is in compressed adjacency form: the neighbours of atom a are
 * bonded_tag[first_bond[a] .. first_bond[a + 1]). For every atom i in
 * [0, natoms):
 *   excls[i]  - the 1-2 partners of i (in graph order), followed by the
 *               distinct atoms reached through one more bond (1-3), excluding
 *               i itself; terminated by -1.
 *   scales[i] - the distinct atoms reached through one further bond from the
 *               1-3 atoms that are neither i nor already in excls[i];
 *               terminated by -1.
 * Rows hold at most MAX_EXCLS_ATOM / MAX_SCALS_ATOM entries including the -1
 * terminator. Overflow is caught by assert instead of silently corrupting the
 * next row.
 *
 * NOTE(review): tags read from the graph are used to index first_bond, so
 * they must still be < natoms here (e.g. not yet remapped by process_pbc) —
 * confirm call order.
 */
void bfs_excls(long natoms, long (*excls)[MAX_EXCLS_ATOM], long (*scales)[MAX_SCALS_ATOM], bond_graph_t *graph) {
  int *first_bond = graph->first_bond;
  long *bonded_tag = graph->bonded_tag;
  for (long i = 0; i < natoms; i++) {
    int nexcl2 = 0;
    /*copy bonded tags to 1-2 list*/
    for (int jj = first_bond[i]; jj < first_bond[i + 1]; jj++) {
      /*keep one slot free for the -1 terminator*/
      assert(nexcl2 < MAX_EXCLS_ATOM - 1);
      excls[i][nexcl2++] = bonded_tag[jj];
    }
    /*derive 1-3 list from 1-2 list*/
    int nexcl3 = nexcl2;
    for (int jj = 0; jj < nexcl2; jj++) {
      long j = excls[i][jj];
      for (int kk = first_bond[j]; kk < first_bond[j + 1]; kk++) {
        long k = bonded_tag[kk];
        /*j->i*/
        if (k == i)
          continue;
        /*remove duplicates: k may already be a 1-2 or 1-3 partner*/
        int found = 0;
        for (int ll = 0; ll < nexcl3; ll++)
          if (excls[i][ll] == k) {
            found = 1;
            break;
          }
        if (found)
          continue;
        assert(nexcl3 < MAX_EXCLS_ATOM - 1);
        excls[i][nexcl3++] = k;
      }
    }
    int nscales = 0;
    /*derive 1-4 list from 1-3 list*/
    for (int jj = nexcl2; jj < nexcl3; jj++) {
      long j = excls[i][jj];
      for (int kk = first_bond[j]; kk < first_bond[j + 1]; kk++) {
        long k = bonded_tag[kk];
        if (k == i)
          continue;
        /*skip atoms already excluded (1-2/1-3) or already scaled*/
        int found = 0;
        for (int ll = 0; ll < nexcl3; ll++)
          if (excls[i][ll] == k) {
            found = 1;
            break;
          }
        if (!found) {
          for (int ll = 0; ll < nscales; ll++)
            if (scales[i][ll] == k) {
              found = 1;
              break;
            }
          if (!found) {
            assert(nscales < MAX_SCALS_ATOM - 1);
            scales[i][nscales++] = k;
          }
        }
      }
    }
    excls[i][nexcl3] = -1;
    scales[i][nscales] = -1;
  }
}

/*
 * Same as bfs_excls, and additionally records every two-bond path from each
 * atom.
 *
 * The graph is in compressed adjacency form: the neighbours of atom a are
 * bonded_tag[first_bond[a] .. first_bond[a + 1]). For every atom i in
 * [0, natoms):
 *   excls[i]  - 1-2 partners of i, then distinct 1-3 atoms (excluding i);
 *               terminated by -1.
 *   chain2[i] - one {j, k} pair per path i-j-k with k != i, where j is a 1-2
 *               partner and k a neighbour of j. Paths are not deduplicated:
 *               a k reachable through several j appears once per j. The list
 *               ends with an entry whose first element is -1.
 *   scales[i] - distinct atoms one bond beyond the 1-3 atoms that are neither
 *               i nor in excls[i]; terminated by -1.
 * Rows hold at most MAX_EXCLS_ATOM (excls, chain2) / MAX_SCALS_ATOM (scales)
 * entries including the terminator. Overflow is caught by assert.
 *
 * NOTE(review): tags read from the graph index first_bond, so they must still
 * be < natoms here — confirm call order relative to process_pbc.
 */
void bfs_excls_chain2(long natoms, long (*excls)[MAX_EXCLS_ATOM], long (*chain2)[MAX_EXCLS_ATOM][2], long (*scales)[MAX_SCALS_ATOM], bond_graph_t *graph) {
  int *first_bond = graph->first_bond;
  long *bonded_tag = graph->bonded_tag;
  for (long i = 0; i < natoms; i++) {
    int nexcl2 = 0;
    /*copy bonded tags to 1-2 list*/
    for (int jj = first_bond[i]; jj < first_bond[i + 1]; jj++) {
      /*keep one slot free for the -1 terminator*/
      assert(nexcl2 < MAX_EXCLS_ATOM - 1);
      excls[i][nexcl2++] = bonded_tag[jj];
    }
    /*derive 1-3 list (and the i-j-k paths) from 1-2 list*/
    int nchain2 = 0;
    int nexcl3 = nexcl2;
    for (int jj = 0; jj < nexcl2; jj++) {
      long j = excls[i][jj];
      for (int kk = first_bond[j]; kk < first_bond[j + 1]; kk++) {
        long k = bonded_tag[kk];
        /*j->i*/
        if (k == i)
          continue;
        assert(nchain2 < MAX_EXCLS_ATOM - 1);
        chain2[i][nchain2][0] = j;
        chain2[i][nchain2][1] = k;
        nchain2++;
        /*remove duplicates: k may already be a 1-2 or 1-3 partner*/
        int found = 0;
        for (int ll = 0; ll < nexcl3; ll++)
          if (excls[i][ll] == k) {
            found = 1;
            break;
          }
        if (found)
          continue;
        assert(nexcl3 < MAX_EXCLS_ATOM - 1);
        excls[i][nexcl3++] = k;
      }
    }
    int nscales = 0;
    /*derive 1-4 list from 1-3 list*/
    for (int jj = nexcl2; jj < nexcl3; jj++) {
      long j = excls[i][jj];
      for (int kk = first_bond[j]; kk < first_bond[j + 1]; kk++) {
        long k = bonded_tag[kk];
        if (k == i)
          continue;
        /*skip atoms already excluded (1-2/1-3) or already scaled*/
        int found = 0;
        for (int ll = 0; ll < nexcl3; ll++)
          if (excls[i][ll] == k) {
            found = 1;
            break;
          }
        if (!found) {
          for (int ll = 0; ll < nscales; ll++)
            if (scales[i][ll] == k) {
              found = 1;
              break;
            }
          if (!found) {
            assert(nscales < MAX_SCALS_ATOM - 1);
            scales[i][nscales++] = k;
          }
        }
      }
    }
    chain2[i][nchain2][0] = -1;
    excls[i][nexcl3] = -1;
    scales[i][nscales] = -1;
  }
}

/*
 * Convert the PSF bond list into a compressed adjacency list.
 *
 * PSF atom indices are 1-based and are shifted to 0-based here. Every bond
 * (a, b) is stored in both directions, and the directed pairs are sorted with
 * cmp_bond (by source atom, then target atom). Afterwards
 *   graph->bonded_tag[graph->first_bond[a] .. graph->first_bond[a + 1])
 * holds the neighbours of atom a in ascending order, and first_bond has
 * natoms + 1 entries. Both arrays come from esmd::allocate and are stored in
 * graph; the temporary pair array is freed.
 */
void build_graph(bond_graph_t *graph, psf_data_t *psf) {
  const int natoms = psf->natom;
  const int nbonds = psf->nbond;
  const int ndirected = 2 * nbonds;

  /* slot k holds a->b and slot k + nbonds holds b->a for the k-th bond */
  int (*pairs)[2] = esmd::allocate<int[2]>(ndirected, "bonded/temp bonds");
  for (int k = 0; k < nbonds; k++) {
    const int a = psf->bond[k][0] - 1;
    const int b = psf->bond[k][1] - 1;
    pairs[k][0] = a;
    pairs[k][1] = b;
    pairs[k + nbonds][0] = b;
    pairs[k + nbonds][1] = a;
  }
  qsort(pairs, ndirected, sizeof(int) * 2, cmp_bond);

  int *first_bond = esmd::allocate<int>(natoms + 1, "bonded/first bond");
  long *bonded_tag = esmd::allocate<long>(ndirected, "bonded/bonded tag");

  /* atom = last atom whose range start has been written */
  int atom = -1;
  for (long k = 0; k < ndirected; k++) {
    /* open the range of every atom up to this pair's source;
       atoms without bonds get empty ranges */
    while (atom < pairs[k][0]) {
      first_bond[++atom] = k;
    }
    bonded_tag[k] = pairs[k][1];
  }
  /* close the remaining ranges, including the end entry first_bond[natoms] */
  while (atom < natoms) {
    first_bond[++atom] = ndirected;
  }
  esmd::deallocate(pairs);

  graph->first_bond = first_bond;
  graph->bonded_tag = bonded_tag;
}

/*
 * Build a per-atom index of improper dihedrals on top of the bond graph.
 *
 * Impropers are copied from the PSF (1-based indices shifted to 0-based) and
 * sorted with cmp_impr, so impropers sharing a first atom are contiguous:
 *   index->first_impr[c] .. index->first_impr[c + 1]
 * is the range of sorted impropers whose first atom is c (natoms + 1 entries).
 * For the sorted improper i with atoms (c, p, q, r), impr_bid[i] = {bp, bq, br},
 * where bX is the position of X in c's neighbour list, i.e.
 * graph->bonded_tag[graph->first_bond[c] + bX] == X. The value is -1 when X
 * is not bonded to c; previously such slots were left uninitialised.
 *
 * NOTE(review): only first_impr and impr_bid are stored; the sorted atom
 * quadruples are freed here, so consumers presumably recover p/q/r through
 * the bond graph — confirm.
 */
void index_impr(impr_index_t *index, bond_graph_t *graph, psf_data_t *psf) {
  int *first_bond = graph->first_bond;
  long *bonded_tag = graph->bonded_tag;
  int natoms = psf->natom;
  int nimprs = psf->nimpr;
  int(*impr)[4] = esmd::allocate<int[4]>(nimprs, "bonded/temp imprs");
  /*copy out, converting PSF 1-based indices to 0-based*/
  for (int i = 0; i < nimprs; i++) {
    for (int k = 0; k < 4; k++) {
      impr[i][k] = psf->impr[i][k] - 1;
    }
  }

  /*sort according to center atom*/
  qsort(impr, nimprs, sizeof(int) * 4, cmp_impr);

  int *first_impr = esmd::allocate<int>(natoms + 1, "bonded/first impr");

  int(*impr_bid)[3] = esmd::allocate<int[3]>(nimprs, "bonded/impr bondid");
  int cur_atom = -1;
  for (long i = 0; i < nimprs; i++) {
    /*the first atom indexes first_impr and first_bond, both natoms + 1 long*/
    assert(impr[i][0] >= 0 && impr[i][0] < natoms);
    while (impr[i][0] > cur_atom) {
      cur_atom++;
      first_impr[cur_atom] = i;
    }
    /*-1 marks a partner that is not bonded to the first atom*/
    impr_bid[i][0] = -1;
    impr_bid[i][1] = -1;
    impr_bid[i][2] = -1;
    /*search over bonded tag, convert to index*/
    for (long j = first_bond[cur_atom]; j < first_bond[cur_atom + 1]; j++) {
      long jtag = bonded_tag[j];
      int rel = (int)(j - first_bond[cur_atom]);
      for (int k = 0; k < 3; k++) {
        if (impr[i][k + 1] == jtag)
          impr_bid[i][k] = rel;
      }
    }
  }

  /*close the remaining ranges, including the end entry first_impr[natoms]*/
  while (cur_atom < natoms) {
    cur_atom++;
    first_impr[cur_atom] = nimprs;
  }
  /*same as bond*/
  index->first_impr = first_impr;
  index->impr_bid = impr_bid;
  esmd::deallocate(impr);
}
